{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-07-19T03:00:30.966738700Z",
     "start_time": "2024-07-19T03:00:25.100849100Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sys.version_info(major=3, minor=12, micro=3, releaselevel='final', serial=0)\n",
      "matplotlib 3.9.0\n",
      "numpy 1.26.4\n",
      "pandas 2.2.2\n",
      "sklearn 1.5.0\n",
      "torch 2.3.1+cpu\n",
      "cpu\n"
     ]
    }
   ],
   "source": [
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "import numpy as np\n",
    "import sklearn\n",
    "import pandas as pd\n",
    "import os\n",
    "import sys\n",
    "import time\n",
    "from tqdm.auto import tqdm\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "print(sys.version_info)\n",
    "for module in mpl, np, pd, sklearn, torch:\n",
    "    print(module.__name__, module.__version__)\n",
    "    \n",
    "device = torch.device(\"cuda:0\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
    "print(device)\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 准备数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-07-19T03:00:44.105190300Z",
     "start_time": "2024-07-19T03:00:43.981163600Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      ".. _california_housing_dataset:\n",
      "\n",
      "California Housing dataset\n",
      "--------------------------\n",
      "\n",
      "**Data Set Characteristics:**\n",
      "\n",
      ":Number of Instances: 20640\n",
      "\n",
      ":Number of Attributes: 8 numeric, predictive attributes and the target\n",
      "\n",
      ":Attribute Information:\n",
      "    - MedInc        median income in block group\n",
      "    - HouseAge      median house age in block group\n",
      "    - AveRooms      average number of rooms per household\n",
      "    - AveBedrms     average number of bedrooms per household\n",
      "    - Population    block group population\n",
      "    - AveOccup      average number of household members\n",
      "    - Latitude      block group latitude\n",
      "    - Longitude     block group longitude\n",
      "\n",
      ":Missing Attribute Values: None\n",
      "\n",
      "This dataset was obtained from the StatLib repository.\n",
      "https://www.dcc.fc.up.pt/~ltorgo/Regression/cal_housing.html\n",
      "\n",
      "The target variable is the median house value for California districts,\n",
      "expressed in hundreds of thousands of dollars ($100,000).\n",
      "\n",
      "This dataset was derived from the 1990 U.S. census, using one row per census\n",
      "block group. A block group is the smallest geographical unit for which the U.S.\n",
      "Census Bureau publishes sample data (a block group typically has a population\n",
      "of 600 to 3,000 people).\n",
      "\n",
      "A household is a group of people residing within a home. Since the average\n",
      "number of rooms and bedrooms in this dataset are provided per household, these\n",
      "columns may take surprisingly large values for block groups with few households\n",
      "and many empty houses, such as vacation resorts.\n",
      "\n",
      "It can be downloaded/loaded using the\n",
      ":func:`sklearn.datasets.fetch_california_housing` function.\n",
      "\n",
      ".. topic:: References\n",
      "\n",
      "    - Pace, R. Kelley and Ronald Barry, Sparse Spatial Autoregressions,\n",
      "      Statistics and Probability Letters, 33 (1997) 291-297\n",
      "\n",
      "(20640, 8)\n",
      "(20640,)\n"
     ]
    }
   ],
   "source": [
    "from sklearn.datasets import fetch_california_housing\n",
    "\n",
    "housing = fetch_california_housing()\n",
    "print(housing.DESCR)\n",
    "print(housing.data.shape)\n",
    "print(housing.target.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-04-25T02:58:48.363080700Z",
     "start_time": "2024-04-25T02:58:48.322104900Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "array([[ 8.32520000e+00,  4.10000000e+01,  6.98412698e+00,\n",
      "         1.02380952e+00,  3.22000000e+02,  2.55555556e+00,\n",
      "         3.78800000e+01, -1.22230000e+02],\n",
      "       [ 8.30140000e+00,  2.10000000e+01,  6.23813708e+00,\n",
      "         9.71880492e-01,  2.40100000e+03,  2.10984183e+00,\n",
      "         3.78600000e+01, -1.22220000e+02],\n",
      "       [ 7.25740000e+00,  5.20000000e+01,  8.28813559e+00,\n",
      "         1.07344633e+00,  4.96000000e+02,  2.80225989e+00,\n",
      "         3.78500000e+01, -1.22240000e+02],\n",
      "       [ 5.64310000e+00,  5.20000000e+01,  5.81735160e+00,\n",
      "         1.07305936e+00,  5.58000000e+02,  2.54794521e+00,\n",
      "         3.78500000e+01, -1.22250000e+02],\n",
      "       [ 3.84620000e+00,  5.20000000e+01,  6.28185328e+00,\n",
      "         1.08108108e+00,  5.65000000e+02,  2.18146718e+00,\n",
      "         3.78500000e+01, -1.22250000e+02]])\n",
      "--------------------------------------------------\n",
      "array([4.526, 3.585, 3.521, 3.413, 3.422])\n"
     ]
    }
   ],
   "source": [
    "# print(housing.data[0:5])\n",
    "import pprint  #打印的格式比较 好看\n",
    "\n",
    "pprint.pprint(housing.data[0:5])\n",
    "print('-'*50)\n",
    "pprint.pprint(housing.target[0:5])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-07-19T03:00:48.864497400Z",
     "start_time": "2024-07-19T03:00:48.793909200Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(11610, 8) (11610,)\n",
      "(3870, 8) (3870,)\n",
      "(5160, 8) (5160,)\n"
     ]
    }
   ],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "#拆分训练集和测试集，random_state是随机种子,同样的随机数种子，是为了得到同样的随机值\n",
    "x_train_all, x_test, y_train_all, y_test = train_test_split(\n",
    "    housing.data, housing.target, random_state = 7)\n",
    "x_train, x_valid, y_train, y_valid = train_test_split(\n",
    "    x_train_all, y_train_all, random_state = 11)\n",
    "# 训练集\n",
    "print(x_train.shape, y_train.shape)\n",
    "# 验证集\n",
    "print(x_valid.shape, y_valid.shape)\n",
    "# 测试集\n",
    "print(x_test.shape, y_test.shape)\n",
    "\n",
    "dataset_maps = {\n",
    "    \"train\": [x_train, y_train],\n",
    "    \"valid\": [x_valid, y_valid],\n",
    "    \"test\": [x_test, y_test],\n",
    "}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-07-19T03:00:51.860770300Z",
     "start_time": "2024-07-19T03:00:51.848298900Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": "StandardScaler()",
      "text/html": "<style>#sk-container-id-1 {\n  /* Definition of color scheme common for light and dark mode */\n  --sklearn-color-text: black;\n  --sklearn-color-line: gray;\n  /* Definition of color scheme for unfitted estimators */\n  --sklearn-color-unfitted-level-0: #fff5e6;\n  --sklearn-color-unfitted-level-1: #f6e4d2;\n  --sklearn-color-unfitted-level-2: #ffe0b3;\n  --sklearn-color-unfitted-level-3: chocolate;\n  /* Definition of color scheme for fitted estimators */\n  --sklearn-color-fitted-level-0: #f0f8ff;\n  --sklearn-color-fitted-level-1: #d4ebff;\n  --sklearn-color-fitted-level-2: #b3dbfd;\n  --sklearn-color-fitted-level-3: cornflowerblue;\n\n  /* Specific color for light theme */\n  --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n  --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, white)));\n  --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n  --sklearn-color-icon: #696969;\n\n  @media (prefers-color-scheme: dark) {\n    /* Redefinition of color scheme for dark theme */\n    --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n    --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, #111)));\n    --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n    --sklearn-color-icon: #878787;\n  }\n}\n\n#sk-container-id-1 {\n  color: var(--sklearn-color-text);\n}\n\n#sk-container-id-1 pre {\n  padding: 0;\n}\n\n#sk-container-id-1 input.sk-hidden--visually {\n  border: 0;\n  clip: rect(1px 1px 1px 1px);\n  clip: rect(1px, 1px, 1px, 1px);\n  height: 1px;\n  margin: -1px;\n  overflow: hidden;\n  padding: 0;\n  position: absolute;\n  width: 1px;\n}\n\n#sk-container-id-1 div.sk-dashed-wrapped {\n  border: 1px dashed var(--sklearn-color-line);\n  margin: 0 0.4em 0.5em 0.4em;\n  box-sizing: border-box;\n  padding-bottom: 0.4em;\n  background-color: var(--sklearn-color-background);\n}\n\n#sk-container-id-1 div.sk-container {\n  /* jupyter's `normalize.less` sets `[hidden] { display: none; }`\n     but bootstrap.min.css set `[hidden] { display: none !important; }`\n     so we also need the `!important` here to be able to override the\n     default hidden behavior on the sphinx rendered scikit-learn.org.\n     See: https://github.com/scikit-learn/scikit-learn/issues/21755 */\n  display: inline-block !important;\n  position: relative;\n}\n\n#sk-container-id-1 div.sk-text-repr-fallback {\n  display: none;\n}\n\ndiv.sk-parallel-item,\ndiv.sk-serial,\ndiv.sk-item {\n  /* draw centered vertical line to link estimators */\n  background-image: linear-gradient(var(--sklearn-color-text-on-default-background), var(--sklearn-color-text-on-default-background));\n  background-size: 2px 100%;\n  background-repeat: no-repeat;\n  background-position: center center;\n}\n\n/* Parallel-specific style estimator block */\n\n#sk-container-id-1 div.sk-parallel-item::after {\n  content: \"\";\n  width: 100%;\n  border-bottom: 2px solid var(--sklearn-color-text-on-default-background);\n  flex-grow: 1;\n}\n\n#sk-container-id-1 div.sk-parallel {\n  display: flex;\n  align-items: stretch;\n  justify-content: center;\n  background-color: var(--sklearn-color-background);\n  position: relative;\n}\n\n#sk-container-id-1 div.sk-parallel-item {\n  display: flex;\n  flex-direction: column;\n}\n\n#sk-container-id-1 div.sk-parallel-item:first-child::after {\n  align-self: flex-end;\n  width: 50%;\n}\n\n#sk-container-id-1 div.sk-parallel-item:last-child::after {\n  align-self: flex-start;\n  width: 50%;\n}\n\n#sk-container-id-1 div.sk-parallel-item:only-child::after {\n  width: 0;\n}\n\n/* Serial-specific style estimator block */\n\n#sk-container-id-1 div.sk-serial {\n  display: flex;\n  flex-direction: column;\n  align-items: center;\n  background-color: var(--sklearn-color-background);\n  padding-right: 1em;\n  padding-left: 1em;\n}\n\n\n/* Toggleable style: style used for estimator/Pipeline/ColumnTransformer box that is\nclickable and can be expanded/collapsed.\n- Pipeline and ColumnTransformer use this feature and define the default style\n- Estimators will overwrite some part of the style using the `sk-estimator` class\n*/\n\n/* Pipeline and ColumnTransformer style (default) */\n\n#sk-container-id-1 div.sk-toggleable {\n  /* Default theme specific background. It is overwritten whether we have a\n  specific estimator or a Pipeline/ColumnTransformer */\n  background-color: var(--sklearn-color-background);\n}\n\n/* Toggleable label */\n#sk-container-id-1 label.sk-toggleable__label {\n  cursor: pointer;\n  display: block;\n  width: 100%;\n  margin-bottom: 0;\n  padding: 0.5em;\n  box-sizing: border-box;\n  text-align: center;\n}\n\n#sk-container-id-1 label.sk-toggleable__label-arrow:before {\n  /* Arrow on the left of the label */\n  content: \"▸\";\n  float: left;\n  margin-right: 0.25em;\n  color: var(--sklearn-color-icon);\n}\n\n#sk-container-id-1 label.sk-toggleable__label-arrow:hover:before {\n  color: var(--sklearn-color-text);\n}\n\n/* Toggleable content - dropdown */\n\n#sk-container-id-1 div.sk-toggleable__content {\n  max-height: 0;\n  max-width: 0;\n  overflow: hidden;\n  text-align: left;\n  /* unfitted */\n  background-color: var(--sklearn-color-unfitted-level-0);\n}\n\n#sk-container-id-1 div.sk-toggleable__content.fitted {\n  /* fitted */\n  background-color: var(--sklearn-color-fitted-level-0);\n}\n\n#sk-container-id-1 div.sk-toggleable__content pre {\n  margin: 0.2em;\n  border-radius: 0.25em;\n  color: var(--sklearn-color-text);\n  /* unfitted */\n  background-color: var(--sklearn-color-unfitted-level-0);\n}\n\n#sk-container-id-1 div.sk-toggleable__content.fitted pre {\n  /* unfitted */\n  background-color: var(--sklearn-color-fitted-level-0);\n}\n\n#sk-container-id-1 input.sk-toggleable__control:checked~div.sk-toggleable__content {\n  /* Expand drop-down */\n  max-height: 200px;\n  max-width: 100%;\n  overflow: auto;\n}\n\n#sk-container-id-1 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {\n  content: \"▾\";\n}\n\n/* Pipeline/ColumnTransformer-specific style */\n\n#sk-container-id-1 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {\n  color: var(--sklearn-color-text);\n  background-color: var(--sklearn-color-unfitted-level-2);\n}\n\n#sk-container-id-1 div.sk-label.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n  background-color: var(--sklearn-color-fitted-level-2);\n}\n\n/* Estimator-specific style */\n\n/* Colorize estimator box */\n#sk-container-id-1 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {\n  /* unfitted */\n  background-color: var(--sklearn-color-unfitted-level-2);\n}\n\n#sk-container-id-1 div.sk-estimator.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n  /* fitted */\n  background-color: var(--sklearn-color-fitted-level-2);\n}\n\n#sk-container-id-1 div.sk-label label.sk-toggleable__label,\n#sk-container-id-1 div.sk-label label {\n  /* The background is the default theme color */\n  color: var(--sklearn-color-text-on-default-background);\n}\n\n/* On hover, darken the color of the background */\n#sk-container-id-1 div.sk-label:hover label.sk-toggleable__label {\n  color: var(--sklearn-color-text);\n  background-color: var(--sklearn-color-unfitted-level-2);\n}\n\n/* Label box, darken color on hover, fitted */\n#sk-container-id-1 div.sk-label.fitted:hover label.sk-toggleable__label.fitted {\n  color: var(--sklearn-color-text);\n  background-color: var(--sklearn-color-fitted-level-2);\n}\n\n/* Estimator label */\n\n#sk-container-id-1 div.sk-label label {\n  font-family: monospace;\n  font-weight: bold;\n  display: inline-block;\n  line-height: 1.2em;\n}\n\n#sk-container-id-1 div.sk-label-container {\n  text-align: center;\n}\n\n/* Estimator-specific */\n#sk-container-id-1 div.sk-estimator {\n  font-family: monospace;\n  border: 1px dotted var(--sklearn-color-border-box);\n  border-radius: 0.25em;\n  box-sizing: border-box;\n  margin-bottom: 0.5em;\n  /* unfitted */\n  background-color: var(--sklearn-color-unfitted-level-0);\n}\n\n#sk-container-id-1 div.sk-estimator.fitted {\n  /* fitted */\n  background-color: var(--sklearn-color-fitted-level-0);\n}\n\n/* on hover */\n#sk-container-id-1 div.sk-estimator:hover {\n  /* unfitted */\n  background-color: var(--sklearn-color-unfitted-level-2);\n}\n\n#sk-container-id-1 div.sk-estimator.fitted:hover {\n  /* fitted */\n  background-color: var(--sklearn-color-fitted-level-2);\n}\n\n/* Specification for estimator info (e.g. \"i\" and \"?\") */\n\n/* Common style for \"i\" and \"?\" */\n\n.sk-estimator-doc-link,\na:link.sk-estimator-doc-link,\na:visited.sk-estimator-doc-link {\n  float: right;\n  font-size: smaller;\n  line-height: 1em;\n  font-family: monospace;\n  background-color: var(--sklearn-color-background);\n  border-radius: 1em;\n  height: 1em;\n  width: 1em;\n  text-decoration: none !important;\n  margin-left: 1ex;\n  /* unfitted */\n  border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n  color: var(--sklearn-color-unfitted-level-1);\n}\n\n.sk-estimator-doc-link.fitted,\na:link.sk-estimator-doc-link.fitted,\na:visited.sk-estimator-doc-link.fitted {\n  /* fitted */\n  border: var(--sklearn-color-fitted-level-1) 1pt solid;\n  color: var(--sklearn-color-fitted-level-1);\n}\n\n/* On hover */\ndiv.sk-estimator:hover .sk-estimator-doc-link:hover,\n.sk-estimator-doc-link:hover,\ndiv.sk-label-container:hover .sk-estimator-doc-link:hover,\n.sk-estimator-doc-link:hover {\n  /* unfitted */\n  background-color: var(--sklearn-color-unfitted-level-3);\n  color: var(--sklearn-color-background);\n  text-decoration: none;\n}\n\ndiv.sk-estimator.fitted:hover .sk-estimator-doc-link.fitted:hover,\n.sk-estimator-doc-link.fitted:hover,\ndiv.sk-label-container:hover .sk-estimator-doc-link.fitted:hover,\n.sk-estimator-doc-link.fitted:hover {\n  /* fitted */\n  background-color: var(--sklearn-color-fitted-level-3);\n  color: var(--sklearn-color-background);\n  text-decoration: none;\n}\n\n/* Span, style for the box shown on hovering the info icon */\n.sk-estimator-doc-link span {\n  display: none;\n  z-index: 9999;\n  position: relative;\n  font-weight: normal;\n  right: .2ex;\n  padding: .5ex;\n  margin: .5ex;\n  width: min-content;\n  min-width: 20ex;\n  max-width: 50ex;\n  color: var(--sklearn-color-text);\n  box-shadow: 2pt 2pt 4pt #999;\n  /* unfitted */\n  background: var(--sklearn-color-unfitted-level-0);\n  border: .5pt solid var(--sklearn-color-unfitted-level-3);\n}\n\n.sk-estimator-doc-link.fitted span {\n  /* fitted */\n  background: var(--sklearn-color-fitted-level-0);\n  border: var(--sklearn-color-fitted-level-3);\n}\n\n.sk-estimator-doc-link:hover span {\n  display: block;\n}\n\n/* \"?\"-specific style due to the `<a>` HTML tag */\n\n#sk-container-id-1 a.estimator_doc_link {\n  float: right;\n  font-size: 1rem;\n  line-height: 1em;\n  font-family: monospace;\n  background-color: var(--sklearn-color-background);\n  border-radius: 1rem;\n  height: 1rem;\n  width: 1rem;\n  text-decoration: none;\n  /* unfitted */\n  color: var(--sklearn-color-unfitted-level-1);\n  border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n}\n\n#sk-container-id-1 a.estimator_doc_link.fitted {\n  /* fitted */\n  border: var(--sklearn-color-fitted-level-1) 1pt solid;\n  color: var(--sklearn-color-fitted-level-1);\n}\n\n/* On hover */\n#sk-container-id-1 a.estimator_doc_link:hover {\n  /* unfitted */\n  background-color: var(--sklearn-color-unfitted-level-3);\n  color: var(--sklearn-color-background);\n  text-decoration: none;\n}\n\n#sk-container-id-1 a.estimator_doc_link.fitted:hover {\n  /* fitted */\n  background-color: var(--sklearn-color-fitted-level-3);\n}\n</style><div id=\"sk-container-id-1\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>StandardScaler()</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-1\" type=\"checkbox\" checked><label for=\"sk-estimator-id-1\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\">&nbsp;&nbsp;StandardScaler<a class=\"sk-estimator-doc-link fitted\" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.5/modules/generated/sklearn.preprocessing.StandardScaler.html\">?<span>Documentation for StandardScaler</span></a><span class=\"sk-estimator-doc-link fitted\">i<span>Fitted</span></span></label><div class=\"sk-toggleable__content fitted\"><pre>StandardScaler()</pre></div> </div></div></div></div>"
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.preprocessing import StandardScaler\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "\n",
    "scaler = StandardScaler()\n",
    "scaler.fit(x_train)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 构建数据集\n",
    "\n",
    "这里我们构建多输入的数据集，注意到数据集介绍里对每个特征定义如下：\n",
    "\n",
    "> **Attribute Information**:\n",
    "> - MedInc        median income in block group\n",
    "> - HouseAge      median house age in block group\n",
    "> - AveRooms      average number of rooms per household\n",
    "> - AveBedrms     average number of bedrooms per household\n",
    "> - Population    block group population\n",
    "> - AveOccup      average number of household members\n",
    "> - Latitude      block group latitude\n",
    "> - Longitude     block group longitude\n",
    "\n",
    "我们认为最后两维作为位置信息要单独处理，故制作数据集如下"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-07-19T03:02:28.127928Z",
     "start_time": "2024-07-19T03:02:28.116426700Z"
    }
   },
   "outputs": [],
   "source": [
    "from torch.utils.data import Dataset\n",
    "\n",
    "class HousingDataset(Dataset):\n",
    "    def __init__(self, mode='train'):\n",
    "        self.x, self.y = dataset_maps[mode]\n",
    "        self.x = torch.from_numpy(scaler.transform(self.x)).float()\n",
    "        self.y = torch.from_numpy(self.y).float().reshape(-1, 1)\n",
    "            \n",
    "    def __len__(self):\n",
    "        return len(self.x)\n",
    "    \n",
    "    def __getitem__(self, idx):\n",
    "        return (self.x[idx],self.x[idx][-2:]), self.y[idx] #返回的是一个元组，元组里面是两个元素，第一个元素是一个元组，第二个元素是一个数。self.x[idx][-2:]代表取最后两个元素\n",
    "    \n",
    "    \n",
    "train_ds = HousingDataset(\"train\")\n",
    "valid_ds = HousingDataset(\"valid\")\n",
    "test_ds = HousingDataset(\"test\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-04-25T02:59:52.379477100Z",
     "start_time": "2024-04-25T02:59:52.339501300Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": "((tensor([-0.2981,  0.3523, -0.1092, -0.2506, -0.0341, -0.0060,  1.0806, -1.0611]),\n  tensor([ 1.0806, -1.0611])),\n tensor([1.5140]))"
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_ds[1]"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### DataLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-07-19T03:02:31.008413200Z",
     "start_time": "2024-07-19T03:02:31.000640600Z"
    }
   },
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader\n",
    "\n",
    "\n",
    "batch_size = 8\n",
    "train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=False)\n",
    "val_loader = DataLoader(valid_ds, batch_size=batch_size, shuffle=False)\n",
    "test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 定义模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-07-19T03:04:00.543191800Z",
     "start_time": "2024-07-19T03:04:00.529712200Z"
    }
   },
   "outputs": [],
   "source": [
    "#回归模型我们只需要1个数\n",
    "\n",
    "class WideDeep(nn.Module):\n",
    "    def __init__(self, input_dim=[8, 2]):\n",
    "        super().__init__()\n",
    "        self.deep = nn.Sequential(\n",
    "            nn.Linear(input_dim[1], 30),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(30, 30),\n",
    "            nn.ReLU()\n",
    "            )\n",
    "        # pytorch 需要自行计算输出输出维度\n",
    "        self.output_layer = nn.Linear(30 + input_dim[0], 1)\n",
    "        \n",
    "        # 初始化权重\n",
    "        self.init_weights()\n",
    "        \n",
    "    def init_weights(self):\n",
    "        \"\"\"使用 xavier 均匀分布来初始化全连接层的权重 W\"\"\"\n",
    "        for m in self.modules():\n",
    "            if isinstance(m, nn.Linear):\n",
    "                nn.init.xavier_uniform_(m.weight)\n",
    "                nn.init.zeros_(m.bias)\n",
    "        \n",
    "    def forward(self, x_wide, x_deep):\n",
    "        # x_deep.shape [batch size, 6]\n",
    "        deep_output = self.deep(x_deep)\n",
    "        # concat [batch size, 30] with [batch size 8]\n",
    "        concat = torch.cat([x_wide, deep_output], dim=1)\n",
    "        logits = self.output_layer(concat)\n",
    "        # logits.shape [batch size, 1]\n",
    "        return logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-07-19T03:04:06.694276400Z",
     "start_time": "2024-07-19T03:04:06.677291500Z"
    }
   },
   "outputs": [],
   "source": [
    "class EarlyStopCallback:\n",
    "    def __init__(self, patience=5, min_delta=0.01):\n",
    "        \"\"\"\n",
    "\n",
    "        Args:\n",
    "            patience (int, optional): Number of epochs with no improvement after which training will be stopped.. Defaults to 5.\n",
    "            min_delta (float, optional): Minimum change in the monitored quantity to qualify as an improvement, i.e. an absolute \n",
    "                change of less than min_delta, will count as no improvement. Defaults to 0.01.\n",
    "        \"\"\"\n",
    "        self.patience = patience\n",
    "        self.min_delta = min_delta\n",
    "        self.best_metric = -1\n",
    "        self.counter = 0\n",
    "        \n",
    "    def __call__(self, metric):\n",
    "        if metric >= self.best_metric + self.min_delta:\n",
    "            # update best metric\n",
    "            self.best_metric = metric\n",
    "            # reset counter \n",
    "            self.counter = 0\n",
    "        else: \n",
    "            self.counter += 1\n",
    "            \n",
    "    @property\n",
    "    def early_stop(self):\n",
    "        return self.counter >= self.patience\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "outputs": [
    {
     "data": {
      "text/plain": "WideDeep(\n  (deep): Sequential(\n    (0): Linear(in_features=2, out_features=30, bias=True)\n    (1): ReLU()\n    (2): Linear(in_features=30, out_features=30, bias=True)\n    (3): ReLU()\n  )\n  (output_layer): Linear(in_features=38, out_features=1, bias=True)\n)"
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = WideDeep()\n",
    "model"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-07-19T03:04:09.126087500Z",
     "start_time": "2024-07-19T03:04:09.112095600Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "deep.0.weight torch.Size([30, 2])\n",
      "deep.0.bias torch.Size([30])\n",
      "deep.2.weight torch.Size([30, 30])\n",
      "deep.2.bias torch.Size([30])\n",
      "output_layer.weight torch.Size([1, 38])\n",
      "output_layer.bias torch.Size([1])\n"
     ]
    }
   ],
   "source": [
    "for name, param in model.named_parameters():\n",
    "      print(name, param.shape)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-04-25T03:07:17.865667600Z",
     "start_time": "2024-04-25T03:07:17.843767800Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-07-19T03:04:18.726974100Z",
     "start_time": "2024-07-19T03:04:18.712294800Z"
    }
   },
   "outputs": [],
   "source": [
    "from sklearn.metrics import accuracy_score\n",
    "\n",
    "@torch.no_grad()\n",
    "def evaluating(model, dataloader, loss_fct):\n",
    "    loss_list = []\n",
    "    for (datas_deep, datas_wide), labels in dataloader:\n",
    "        datas_deep = datas_deep.to(device)\n",
    "        datas_wide = datas_wide.to(device)\n",
    "        labels = labels.to(device)\n",
    "        # 前向计算\n",
    "        logits = model(datas_deep, datas_wide)\n",
    "        loss = loss_fct(logits, labels)         # 验证集损失\n",
    "        loss_list.append(loss.item())\n",
    "        \n",
    "    return np.mean(loss_list)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-07-19T03:06:11.746963700Z",
     "start_time": "2024-07-19T03:05:35.793150100Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": "  0%|          | 0/14520 [00:00<?, ?it/s]",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "66529da2188247d4ae5f963c026a1f06"
      }
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# 训练\n",
    "def training(\n",
    "    model, \n",
    "    train_loader, \n",
    "    val_loader, \n",
    "    epoch, \n",
    "    loss_fct, \n",
    "    optimizer, \n",
    "    tensorboard_callback=None,\n",
    "    save_ckpt_callback=None,\n",
    "    early_stop_callback=None,\n",
    "    eval_step=500,\n",
    "    ):\n",
    "    record_dict = {\n",
    "        \"train\": [],\n",
    "        \"val\": []\n",
    "    }\n",
    "    \n",
    "    global_step = 0\n",
    "    model.train()\n",
    "    with tqdm(total=epoch * len(train_loader)) as pbar:\n",
    "        for epoch_id in range(epoch):\n",
    "            # training\n",
    "            for (datas_deep, datas_wide), labels in train_loader:\n",
    "                datas_deep = datas_deep.to(device)\n",
    "                datas_wide = datas_wide.to(device)\n",
    "                labels = labels.to(device)\n",
    "                # 梯度清空\n",
    "                optimizer.zero_grad()\n",
    "                # 模型前向计算\n",
    "                logits = model(datas_deep, datas_wide)\n",
    "                # 计算损失\n",
    "                loss = loss_fct(logits, labels)\n",
    "                # 梯度回传\n",
    "                loss.backward()\n",
    "                # 调整优化器，包括学习率的变动等\n",
    "                optimizer.step()\n",
    " \n",
    "                loss = loss.cpu().item()\n",
    "                # record\n",
    "                \n",
    "                record_dict[\"train\"].append({\n",
    "                    \"loss\": loss, \"step\": global_step\n",
    "                })\n",
    "                \n",
    "                # evaluating\n",
    "                if global_step % eval_step == 0:\n",
    "                    model.eval()\n",
    "                    val_loss = evaluating(model, val_loader, loss_fct)\n",
    "                    record_dict[\"val\"].append({\n",
    "                        \"loss\": val_loss, \"step\": global_step\n",
    "                    })\n",
    "                    model.train()\n",
    "\n",
    "                    # 早停 Early Stop\n",
    "                    if early_stop_callback is not None:\n",
    "                        early_stop_callback(-val_loss)\n",
    "                        if early_stop_callback.early_stop:\n",
    "                            print(f\"Early stop at epoch {epoch_id} / global_step {global_step}\")\n",
    "                            return record_dict\n",
    "                    \n",
    "                # udate step\n",
    "                global_step += 1\n",
    "                pbar.update(1)\n",
    "                pbar.set_postfix({\"epoch\": epoch_id})\n",
    "        \n",
    "    return record_dict\n",
    "        \n",
    "\n",
    "epoch = 10\n",
    "\n",
    "\n",
    "\n",
    "# 1. 定义损失函数 采用MSE损失\n",
    "loss_fct = nn.MSELoss()\n",
    "# 2. 定义优化器 采用SGD\n",
    "# Optimizers specified in the torch.optim package\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.0)\n",
    "\n",
    "# 3. early stop\n",
    "early_stop_callback = EarlyStopCallback(patience=5, min_delta=1e-3)\n",
    "\n",
    "model = model.to(device)\n",
    "record = training(\n",
    "    model, \n",
    "    train_loader, \n",
    "    val_loader, \n",
    "    epoch, \n",
    "    loss_fct, \n",
    "    optimizer, \n",
    "    early_stop_callback=early_stop_callback,\n",
    "    eval_step=len(train_loader)\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#画线要注意的是损失是不一定在零到1之间的\n",
    "def plot_learning_curves(record_dict, sample_step=500):\n",
    "    # build DataFrame\n",
    "    train_df = pd.DataFrame(record_dict[\"train\"]).set_index(\"step\").iloc[::sample_step]\n",
    "    val_df = pd.DataFrame(record_dict[\"val\"]).set_index(\"step\")\n",
    "\n",
    "    # plot\n",
    "    for idx, item in enumerate(train_df.columns):\n",
    "        plt.plot(train_df.index, train_df[item], label=f\"train_{item}\")\n",
    "        plt.plot(val_df.index, val_df[item], label=f\"val_{item}\")\n",
    "        plt.grid()\n",
    "        plt.legend()\n",
    "        # plt.xticks(range(0, train_df.index[-1], 10*sample_step), range(0, train_df.index[-1], 10*sample_step))\n",
    "        plt.xlabel(\"step\")\n",
    "\n",
    "        plt.show()\n",
    "\n",
    "plot_learning_curves(record, sample_step=500)  #横坐标是 steps"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 测试集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.eval()\n",
    "loss = evaluating(model, val_loader, loss_fct)\n",
    "print(f\"loss:     {loss:.4f}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
